package com.study.crypto.basic.utils;

import com.study.crypto.basic.asn1.sm2.SM2EnvelopedKey;
import org.bouncycastle.asn1.*;
import org.bouncycastle.asn1.gm.GMNamedCurves;
import org.bouncycastle.asn1.gm.GMObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.sec.SECObjectIdentifiers;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.asn1.x509.Certificate;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.asn1.x9.X9ObjectIdentifiers;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.spec.ECParameterSpec;
import org.bouncycastle.jce.spec.ECPrivateKeySpec;
import org.bouncycastle.jce.spec.ECPublicKeySpec;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.openssl.PEMDecryptorProvider;
import org.bouncycastle.openssl.PEMEncryptedKeyPair;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;
import org.bouncycastle.operator.InputDecryptorProvider;
import org.bouncycastle.operator.OperatorException;
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
import org.bouncycastle.pkcs.PKCSException;
import org.bouncycastle.pqc.math.linearalgebra.ByteUtils;
import org.bouncycastle.util.BigIntegers;
import org.bouncycastle.util.encoders.Base64;
import org.bouncycastle.util.encoders.Hex;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.StringReader;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.*;

/**
 * Key utilities built on the BouncyCastle ("BC") provider.
 * <p>
 * Covers key-pair generation, conversion between raw SM2 key bytes and JCA key objects, PEM
 * encoding/parsing, {@link SubjectPublicKeyInfo} construction, and the SM2 enveloped key-pair
 * structure of GB/T 35276-2017.
 *
 * @author Songjin
 * @since 2020-12-30 19:00
 */
public final class KeyUtils {
    
    public static final String BC = BouncyCastleProvider.PROVIDER_NAME;
    /** Name of the SM2 elliptic curve (Chinese national "GM" standard curve). */
    public static final String CURVE_NAME_SM2 = "sm2p256v1";
    public static final String ALGO_SM2 = "SM2";
    public static final String ALGO_RSA = "RSA";
    public static final String ALGO_ECC = "ECC";
    
    /**
     * SM2 curve parameters.
     * <p>
     * NOTE: left non-final only for source/binary compatibility; treat as read-only. The private
     * parameter spec used by this class is derived from it once, during class initialisation.
     */
    public static X9ECParameters sm2p256v1 = GMNamedCurves.getByName(CURVE_NAME_SM2);
    private static final ECParameterSpec ecParameterSpec = new ECParameterSpec(sm2p256v1.getCurve(), sm2p256v1.getG(), sm2p256v1.getN());
    /** Byte length of an SM2 private scalar (the bit length of the group order N, rounded up): 32. */
    private static final int SM2_SCALAR_BYTES = (ecParameterSpec.getN().bitLength() + 7) / 8;
    
    static {
        // Register BouncyCastle so the look-ups by provider name "BC" below succeed.
        Security.addProvider(new BouncyCastleProvider());
    }
    
    private KeyUtils() {
    }
    
    /**
     * Generates a fresh key pair with the BouncyCastle provider.
     * <ul>
     *   <li>{@code "RSA"}: 2048-bit RSA key pair</li>
     *   <li>{@code "SM2"}: EC key pair on the sm2p256v1 curve, using BC's "ECDSA" generator</li>
     *   <li>{@code "ECC"}: EC key pair on the secp256k1 curve, using BC's "ECDSA" generator</li>
     * </ul>
     * The algorithm name is matched case-insensitively.
     *
     * @param algo algorithm name: {@link #ALGO_RSA}, {@link #ALGO_SM2} or {@link #ALGO_ECC}
     * @return the generated key pair
     * @throws NoSuchProviderException            if the BC provider is not registered
     * @throws NoSuchAlgorithmException           if BC does not offer the requested generator
     * @throws InvalidAlgorithmParameterException if the generator rejects the curve parameters
     * @throws IllegalArgumentException           if {@code algo} is {@code null} or unsupported
     */
    public static KeyPair generateKeyPair(String algo) throws NoSuchProviderException, NoSuchAlgorithmException, InvalidAlgorithmParameterException {
        KeyPairGenerator keyPairGenerator;
        SecureRandom random = new SecureRandom();
        if (ALGO_RSA.equalsIgnoreCase(algo)) {
            keyPairGenerator = KeyPairGenerator.getInstance(ALGO_RSA, BC);
            keyPairGenerator.initialize(2048, random);
        } else if (ALGO_SM2.equalsIgnoreCase(algo)) {
            ECParameterSpec ecSpec = ECNamedCurveTable.getParameterSpec(CURVE_NAME_SM2);
            keyPairGenerator = KeyPairGenerator.getInstance("ECDSA", BC);
            keyPairGenerator.initialize(ecSpec, random);
        } else if (ALGO_ECC.equalsIgnoreCase(algo)) {
            ECParameterSpec ecSpec = ECNamedCurveTable.getParameterSpec("secp256k1");
            keyPairGenerator = KeyPairGenerator.getInstance("ECDSA", BC);
            keyPairGenerator.initialize(ecSpec, random);
        } else {
            throw new IllegalArgumentException("算法不支持");
        }
        return keyPairGenerator.generateKeyPair();
    }
    
    /**
     * Parses a DER-encoded X.509 certificate and returns the raw bits of its subject public key.
     * <p>
     * When the key bits are 65 bytes long, the leading byte is dropped, which for an uncompressed
     * 256-bit EC point {@code 0x04 || X || Y} yields the 64-byte {@code X || Y}. Any other length
     * (e.g. an RSA key) is returned unchanged.
     *
     * @param certBytes DER-encoded certificate
     * @return the public-key bytes (64 bytes for an uncompressed 256-bit EC key)
     */
    public static byte[] obtainPublicKeyBytes(byte[] certBytes) {
        Certificate certificate = Certificate.getInstance(certBytes);
        byte[] publicKeyBytes = certificate.getSubjectPublicKeyInfo().getPublicKeyData().getBytes();
        if (publicKeyBytes.length == 65) {
            return ByteUtils.subArray(publicKeyBytes, 1);
        }
        return publicKeyBytes;
    }
    
    /**
     * Converts a {@link SubjectPublicKeyInfo} into a JCA {@link PublicKey}.
     * <p>
     * Before converting, it registers the RSA and EC key-info converters on the installed BC provider,
     * then delegates to {@link BouncyCastleProvider#getPublicKey(SubjectPublicKeyInfo)}.
     *
     * @param publicKeyInfo the encoded public key
     * @return the decoded public key
     * @throws IOException if the key data cannot be parsed
     */
    public static PublicKey obtainPublicKey(SubjectPublicKeyInfo publicKeyInfo) throws IOException {
        BouncyCastleProvider bouncyCastleProvider = ((BouncyCastleProvider) Security.getProvider(BC));
        bouncyCastleProvider.addKeyInfoConverter(PKCSObjectIdentifiers.rsaEncryption, new org.bouncycastle.jcajce.provider.asymmetric.rsa.KeyFactorySpi());
        bouncyCastleProvider.addKeyInfoConverter(X9ObjectIdentifiers.id_ecPublicKey,  new org.bouncycastle.jcajce.provider.asymmetric.ec.KeyFactorySpi.EC());
        return BouncyCastleProvider.getPublicKey(publicKeyInfo);
    }
    
    /**
     * Rebuilds an SM2 public key from its raw coordinates.
     * <p>
     * Accepted input forms:
     * <ul>
     *   <li>{@code X || Y}: 32 bytes each, big-endian. Only the first 64 bytes are read.</li>
     *   <li>The 65-byte uncompressed encoding {@code 0x04 || X || Y}. The tag byte is skipped.</li>
     * </ul>
     *
     * @param publicKeyBytes raw SM2 public key: the x and y components
     * @return the public key, with algorithm name "EC"
     * @throws IllegalArgumentException if fewer than 64 coordinate bytes are supplied, or if the
     *                                  point is not on the SM2 curve
     */
    public static PublicKey convertPublicKey(byte[] publicKeyBytes) {
        // Tolerate the SEC1 uncompressed form by skipping its 0x04 tag byte.
        int offset = (publicKeyBytes.length == 65 && publicKeyBytes[0] == 0x04) ? 1 : 0;
        if (publicKeyBytes.length - offset < 64) {
            throw new IllegalArgumentException(
                    "SM2 public key needs 64 bytes (X || Y), got " + publicKeyBytes.length);
        }
        byte[] publicKeyX = new byte[32];
        byte[] publicKeyY = new byte[32];
        System.arraycopy(publicKeyBytes, offset, publicKeyX, 0, 32);
        System.arraycopy(publicKeyBytes, offset + 32, publicKeyY, 0, 32);
        ECCurve curve = ecParameterSpec.getCurve();
        // validatePoint rejects coordinates that do not lie on the SM2 curve, so externally
        // supplied garbage cannot silently become a usable key.
        ECPoint point = curve.validatePoint(new BigInteger(1, publicKeyX), new BigInteger(1, publicKeyY));
        ECPublicKeySpec keySpec = new ECPublicKeySpec(point, ecParameterSpec);
        return new BCECPublicKey("EC", keySpec, BouncyCastleProvider.CONFIGURATION);
    }
    
    /**
     * Rebuilds an SM2 private key from its raw scalar.
     *
     * @param privateKey raw SM2 private key D, as big-endian unsigned bytes
     * @return the private key, with algorithm name "EC"
     */
    public static PrivateKey convertPrivateKey(byte[] privateKey) {
        ECPrivateKeySpec ecPrivateKeySpec = new ECPrivateKeySpec(new BigInteger(1, privateKey), ecParameterSpec);
        return new BCECPrivateKey("EC", ecPrivateKeySpec, BouncyCastleProvider.CONFIGURATION);
    }
    
    /**
     * Packs an SM2 public/private key pair into a single fixed-layout byte array.
     * <p>
     * Layout: {@code X (32) || Y (32) || D (32)}, 96 bytes in total, all big-endian unsigned.
     *
     * @param publicKey  SM2 public key
     * @param privateKey SM2 private key
     * @return the packed key-pair bytes
     * @throws IllegalArgumentException if D does not fit in 32 bytes
     */
    public static byte[] packageKeyPairSM2(BCECPublicKey publicKey, BCECPrivateKey privateKey) {
        // The uncompressed encoding is 0x04 || X || Y with fixed-width coordinates; drop the tag.
        // Do NOT round-trip through a signed BigInteger: that strips leading 0x00 (or 0xFF) bytes
        // and shifts the layout.
        byte[] publicKeyBytes = ByteUtils.subArray(publicKey.getQ().getEncoded(false), 1);
        // Left-pad D to a fixed width; the plain unsigned encoding drops leading zero bytes.
        byte[] privateKeyBytes = BigIntegers.asUnsignedByteArray(SM2_SCALAR_BYTES, privateKey.getD());
        return ByteUtils.concatenate(publicKeyBytes, privateKeyBytes);
    }
    
    /**
     * Serialises an object (key, certificate, ...) to PEM text using {@link JcaPEMWriter}.
     *
     * @param target the object to encode
     * @return the PEM text
     * @throws IOException if the writer cannot encode {@code target}
     */
    public static String toPEMText(Object target) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // Use an explicit charset on both sides instead of the platform default; closing the
        // writer flushes it.
        try (JcaPEMWriter pemWriter = new JcaPEMWriter(new OutputStreamWriter(bos, StandardCharsets.UTF_8))) {
            pemWriter.writeObject(target);
        }
        return new String(bos.toByteArray(), StandardCharsets.UTF_8);
    }
    
    /**
     * Reads a private key from a PEM string, decrypting it with the given password if needed.
     *
     * @param pemString PEM text
     * @param password  decryption password
     * @return the private key
     * @throws PKCSException      if PKCS#8 decryption fails
     * @throws OperatorException  if the decryptor cannot be built
     * @throws IOException        if the PEM cannot be read or holds no supported key
     * @throws ClassCastException if the PEM holds a public key
     */
    public static PrivateKey privateKey(String pemString, String password) throws PKCSException, OperatorException, IOException {
        return (PrivateKey) parseKey(pemString, password);
    }
    
    /**
     * Reads an unencrypted private key from a PEM string.
     *
     * @param pemString PEM text
     * @return the private key
     * @throws PKCSException      if PKCS#8 decryption fails
     * @throws OperatorException  if the decryptor cannot be built
     * @throws IOException        if the PEM cannot be read or holds no supported key
     * @throws ClassCastException if the PEM holds a public key
     */
    public static PrivateKey privateKey(String pemString) throws PKCSException, OperatorException, IOException {
        return (PrivateKey) parseKey(pemString, null);
    }
    
    /**
     * Reads a public key ({@code SubjectPublicKeyInfo}) from a PEM string.
     *
     * @param pemString PEM text
     * @return the public key
     * @throws PKCSException      if PKCS#8 decryption fails
     * @throws OperatorException  if the decryptor cannot be built
     * @throws IOException        if the PEM cannot be read or holds no supported key
     * @throws ClassCastException if the PEM holds a private key or a key pair (the private half
     *                            is what {@link #parseKey} returns for those)
     */
    public static PublicKey publicKey(String pemString) throws PKCSException, OperatorException, IOException {
        return (PublicKey) parseKey(pemString, null);
    }
    
    /**
     * Parses a {@link Key} from the first object of a PEM representation.
     * <p>
     * Supported contents:
     * <ul>
     *   <li>OpenSSL-encrypted key pair: decrypted, private key returned</li>
     *   <li>encrypted PKCS#8 private key: decrypted and returned</li>
     *   <li>unencrypted PKCS#8 private key: returned</li>
     *   <li>SubjectPublicKeyInfo: public key returned</li>
     *   <li>unencrypted key pair: private key returned</li>
     * </ul>
     *
     * @param pemString  PEM text (must not be null)
     * @param passPhrase optional pass phrase, required if the key is encrypted; {@code null} is
     *                   treated as the empty string
     * @return the parsed key (never null)
     * @throws IOException       if the PEM cannot be read, is empty, or holds an unsupported object
     * @throws OperatorException if the PKCS#8 decryptor cannot be built
     * @throws PKCSException     if PKCS#8 decryption fails
     */
    public static Key parseKey(String pemString, String passPhrase) throws IOException, OperatorException, PKCSException {
        char[] password = (passPhrase == null ? "" : passPhrase).toCharArray();
        
        final Object object;
        try (PEMParser pemParser = new PEMParser(new StringReader(pemString))) {
            object = pemParser.readObject();
        }
        final JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME);
        
        if (object instanceof PEMEncryptedKeyPair) {
            // OpenSSL-style encrypted key pair: decrypt with the supplied password.
            final PEMDecryptorProvider decProv = new JcePEMDecryptorProviderBuilder().build(password);
            return converter.getKeyPair(((PEMEncryptedKeyPair) object).decryptKeyPair(decProv)).getPrivate();
        }
        if (object instanceof PKCS8EncryptedPrivateKeyInfo) {
            // Encrypted PKCS#8: decrypt with the supplied password.
            final InputDecryptorProvider provider = new JceOpenSSLPKCS8DecryptorProviderBuilder().build(password);
            final PrivateKeyInfo privateKeyInfo = ((PKCS8EncryptedPrivateKeyInfo) object).decryptPrivateKeyInfo(provider);
            return converter.getPrivateKey(privateKeyInfo);
        }
        if (object instanceof PrivateKeyInfo) {
            return converter.getPrivateKey((PrivateKeyInfo) object);
        }
        if (object instanceof SubjectPublicKeyInfo) {
            return converter.getPublicKey((SubjectPublicKeyInfo) object);
        }
        if (object instanceof PEMKeyPair) {
            // Unencrypted key pair - no password needed.
            return converter.getKeyPair((PEMKeyPair) object).getPrivate();
        }
        // Previously this fell through to an unchecked cast (ClassCastException / NPE).
        throw new IOException("Unsupported PEM content: "
                + (object == null ? "no PEM object found" : object.getClass().getName()));
    }
    
    /**
     * Converts a BC EC public key object into the low-level {@link ECPublicKeyParameters}.
     *
     * @param ecPubKey the public key
     * @return the public-key parameters (point Q plus curve, G, N and H domain parameters)
     */
    public static ECPublicKeyParameters convertPublicKeyToParameters(BCECPublicKey ecPubKey) {
        ECParameterSpec parameterSpec = ecPubKey.getParameters();
        ECDomainParameters domainParameters = new ECDomainParameters(parameterSpec.getCurve(), parameterSpec.getG(),
                parameterSpec.getN(), parameterSpec.getH());
        return new ECPublicKeyParameters(ecPubKey.getQ(), domainParameters);
    }
    
    /**
     * Wraps raw public-key bytes in a {@link SubjectPublicKeyInfo}. The public-key algorithm
     * identifier is chosen from the signature algorithm:
     * <ul>
     *   <li>SM2-with-SM3: id-ecPublicKey with sm2p256v1</li>
     *   <li>ECDSA-with-SHA256: id-ecPublicKey with secp256k1</li>
     *   <li>SHA256-with-RSA: rsaEncryption with NULL parameters</li>
     * </ul>
     *
     * @param signOid        signature algorithm identifier
     * @param publicKeyBytes public-key bytes, stored verbatim as the key BIT STRING
     * @return the SubjectPublicKeyInfo
     * @throws InvalidAlgorithmParameterException if {@code signOid} is not one of the above
     */
    public static SubjectPublicKeyInfo obtainSubjectPublicKeyInfo(ASN1ObjectIdentifier signOid, byte[] publicKeyBytes)
            throws InvalidAlgorithmParameterException {
        AlgorithmIdentifier keyAlgorithm;
        if (signOid.equals(GMObjectIdentifiers.sm2sign_with_sm3)) {
            keyAlgorithm = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey, GMObjectIdentifiers.sm2p256v1);
        } else if (signOid.equals(X9ObjectIdentifiers.ecdsa_with_SHA256)) {
            keyAlgorithm = new AlgorithmIdentifier(X9ObjectIdentifiers.id_ecPublicKey, SECObjectIdentifiers.secp256k1);
        } else if (signOid.equals(PKCSObjectIdentifiers.sha256WithRSAEncryption)) {
            keyAlgorithm = new AlgorithmIdentifier(PKCSObjectIdentifiers.rsaEncryption, DERNull.INSTANCE);
        } else {
            throw new InvalidAlgorithmParameterException("Unsupported signature algorithm: " + signOid);
        }
        // Equivalent to SEQUENCE { algorithm, BIT STRING publicKeyBytes }.
        return new SubjectPublicKeyInfo(keyAlgorithm, publicKeyBytes);
    }

    /**
     * Wraps a Base64-encoded raw public key in a {@link SubjectPublicKeyInfo}. The public-key
     * algorithm identifier is derived from the signature algorithm.
     *
     * @param signOid      signature algorithm identifier
     * @param publicKeyB64 Base64-encoded public-key bytes
     * @return the SubjectPublicKeyInfo
     * @throws InvalidAlgorithmParameterException if {@code signOid} is unsupported
     */
    public static SubjectPublicKeyInfo obtainSubjectPublicKeyInfoBase64(ASN1ObjectIdentifier signOid, String publicKeyB64)
            throws InvalidAlgorithmParameterException {
        return obtainSubjectPublicKeyInfo(signOid, Base64.decode(publicKeyB64));
    }

    /**
     * Wraps a Base64-encoded raw SM2 public key in a {@link SubjectPublicKeyInfo}
     * (id-ecPublicKey / sm2p256v1).
     *
     * @param publicKeyBase64 Base64-encoded public key, including the 0x04 prefix
     * @return the SubjectPublicKeyInfo
     * @throws InvalidAlgorithmParameterException never thrown for SM2 in practice; declared by the delegate
     */
    public static SubjectPublicKeyInfo obtainSubjectPublicKeyInfoBase64(String publicKeyBase64) throws InvalidAlgorithmParameterException {
        return obtainSubjectPublicKeyInfo(GMObjectIdentifiers.sm2sign_with_sm3, Base64.decode(publicKeyBase64));
    }
    
    /**
     * Builds the DER-encoded SM2 enveloped key-pair structure defined in GB/T 35276-2017
     * (SM2 cryptographic algorithm usage specification).
     * <p>
     * A fresh SM4 key is generated and encrypted (SM2, C1C3C2) under the external public key. The
     * protected key pair's private scalar is then encrypted with that SM4 key (ECB, no padding).
     *
     * @param publicKeyB64 Base64 external (recipient) SM2 public key. Normally
     *                     {@code 0x04 || X || Y}; a bare 64-byte {@code X || Y} is also accepted.
     * @param keyPair      the SM2 key pair to protect (BC EC key objects)
     * @return DER encoding of the SM2EnvelopedKey structure
     * @throws Exception if any cryptographic step fails
     */
    public static byte[] generateEnvelopedKey(String publicKeyB64, KeyPair keyPair) throws Exception{
        byte[] encodedRecipientKey = Base64.decode(publicKeyB64);
        // Strip the leading tag byte (as before) unless the input is already the bare X || Y.
        byte[] rawRecipientKey = encodedRecipientKey.length == 64
                ? encodedRecipientKey
                : ByteUtils.subArray(encodedRecipientKey, 1);
        BCECPublicKey recipientKey = (BCECPublicKey) convertPublicKey(rawRecipientKey);
        byte[] symmetric = SM4Utils.generateKey();
        BCECPrivateKey privateKey = (BCECPrivateKey) keyPair.getPrivate();
        BCECPublicKey publicKey = (BCECPublicKey) keyPair.getPublic();
        // 1. Symmetric algorithm identifier (SM4-ECB).
        AlgorithmIdentifier symAlgID = new AlgorithmIdentifier(GMObjectIdentifiers.sms4_ecb, DERNull.INSTANCE);
        // 2. SM2 ciphertext structure (SM2Cipher) of the symmetric key, under the recipient key.
        byte[] encrypt = SM2Utils.encrypt(recipientKey, symmetric, SM2Engine.Mode.C1C3C2);
        ASN1Sequence cipher = ASN1Sequence.getInstance(encrypt);
        // 3. Public key of the protected pair (uncompressed point).
        DERBitString sm2PublicKey = new DERBitString(publicKey.getQ().getEncoded(false));
        // 4. Encrypted private key. D is encoded as a fixed 32-byte value. The length-less
        //    asUnsignedByteArray(D) drops leading zero bytes (31 bytes for ~1/256 keys), which is
        //    not a whole SM4 block and presumably fails the no-padding encryption.
        byte[] privateKeyBytes = BigIntegers.asUnsignedByteArray(SM2_SCALAR_BYTES, privateKey.getD());
        byte[] encryptedPrivateKey = SM4Utils.encrypt_ecb_nopadding(symmetric, privateKeyBytes);
        DERBitString sm2EncryptedPrivateKey = new DERBitString(encryptedPrivateKey);
        SM2EnvelopedKey envelopedKey = new SM2EnvelopedKey(symAlgID, cipher, sm2PublicKey, sm2EncryptedPrivateKey);
        return envelopedKey.getEncoded(ASN1Encoding.DER);
    }

}